{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71a329f0",
   "metadata": {},
   "source": [
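    "# Reranking model deployment guide\n",
    "In this tutorial, you will deploy a reranking model to SageMaker using the LMI (Large Model Inference) container from the AWS Deep Learning Containers (DLC), then run inference against the endpoint.\n",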
    "\n",
    "Make sure the following permission is granted before running the notebook:\n",
    "\n",
    "- SageMaker access\n",
    "\n",
    "## Step 1: Upgrade the SageMaker SDK and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fa3208",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9ac353",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import image_uris\n",
    "from sagemaker.djl_inference.model import DJLModel\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e58cf33",
   "metadata": {},
   "source": [
    "## Step 2: Build the SageMaker endpoint\n",
    "In this step, we build a SageMaker endpoint from scratch."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d955679",
   "metadata": {},
   "source": [
    "### Getting the container image URI\n",
    "\n",
    "See the list of available images: [Large Model Inference containers](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a174b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 1: pin a specific LMI image URI directly:\n",
    "# image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124\"\n",
    "# Option 2: retrieve the latest LMI image for a region:\n",
    "# image_uri = image_uris.retrieve(framework=\"djl-lmi\", region=\"us-west-2\", version=\"latest\")\n",
    "# Option 3 (used here): retrieve a specific LMI version for a region:\n",
    "image_uri = image_uris.retrieve(framework=\"djl-lmi\", region=\"us-west-2\", version=\"0.31.0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11601839",
   "metadata": {},
   "source": [
    "### Create SageMaker model\n",
    "\n",
    "You can deploy a model from the Hugging Face Hub, the DJL model zoo, or your own S3 bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b1e5ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_id = \"djl://ai.djl.huggingface.rust/BAAI/bge-reranker-v2-m3\"  # use a DJL model zoo model\n",
    "# model_id = \"s3://YOUR_BUCKET\"  # download the model from your S3 bucket\n",
    "model_id = \"BAAI/bge-reranker-v2-m3\"  # the model will be downloaded from the Hugging Face Hub\n",
    "\n",
    "env = {\n",
    "    \"SERVING_BATCH_SIZE\": \"32\",   # enable dynamic batching with a max batch size of 32\n",
    "    \"SERVING_MIN_WORKERS\": \"1\",   # keep min and max workers equal when deploying the model on GPU\n",
    "    \"SERVING_MAX_WORKERS\": \"1\",\n",
    "    \"ARGS_RERANKING\": \"true\",     # enable reranking\n",
    "}\n",
    "\n",
    "model = DJLModel(\n",
    "    model_id=model_id,\n",
    "    task=\"text-embedding\",\n",
    "    # engine=\"Rust\",        # explicitly choose the Rust engine\n",
    "    # image_uri=image_uri,   # choose a specific version of the LMI DLC image\n",
    "    env=env,\n",
    "    role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "004f39f6",
   "metadata": {},
   "source": [
    "### Create SageMaker endpoint\n",
    "\n",
    "Specify the instance type and the endpoint name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e0e61cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_type = \"ml.g4dn.2xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-text-embedding\")\n",
    "\n",
    "predictor = model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=instance_type,\n",
    "    endpoint_name=endpoint_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb63ee65",
   "metadata": {},
   "source": [
    "## Step 3: Test and benchmark the inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79786708",
   "metadata": {},
   "source": [
    "Run inference with a single query/passage pair:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62714a82-45f8-4900-88bd-5137b85e8193",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"key\": \"What is Deep Learning?\", \"value\": \"Deep Learning is not\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2a6fb98-39a1-4a2a-9d30-9d23efeae44c",
   "metadata": {},
   "source": [
    "**Note:** The following request format requires LMI 0.30.0+"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bcef095",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"text\": \"What is Deep Learning?\", \"text_pair\": \"Deep Learning is not\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2170adb9-3d9e-4605-8787-bfe1b0f3ecfa",
   "metadata": {},
   "source": [
    "You can also batch requests on the client side:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19ec54e1-c1f3-4ce5-84f9-be30bcd5db9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    [\n",
    "        {\"text\": \"What is Deep Learning?\", \"text_pair\": \"Deep Learning is not\"},\n",
    "        {\"text\": \"Hello\", \"text_pair\": \"Hi\"},\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cec7b0f3-dd51-42fe-9f8f-4c322e8ba9f3",
   "metadata": {},
   "source": [
    "The Hugging Face TEI-style batch input format is also supported:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdee60c9-a650-4a2e-9629-3e05edfc2138",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"query\": \"What is Deep Learning?\", \"texts\": [\"Deep Learning is not\", \"Deep learning is\"]}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1cd9042",
   "metadata": {},
   "source": [
    "## Clean up the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d674b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "session.delete_endpoint(endpoint_name)\n",
    "session.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
